{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 🤪 WGAN - CelebA Faces"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own Wasserstein GAN on the CelebA faces dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "098731a6-4193-4fda-9018-b14114c54250",
   "metadata": {},
   "source": [
    "The code has been adapted from the excellent [WGAN-GP tutorial](https://keras.io/examples/generative/wgan_gp/) created by Aakash Kumar Nain, available on the Keras website."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import (\n",
    "    layers,\n",
    "    models,\n",
    "    callbacks,\n",
    "    utils,\n",
    "    metrics,\n",
    "    optimizers,\n",
    ")\n",
    "\n",
    "from notebooks.utils import display, sample_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_SIZE = 64\n",
    "CHANNELS = 3\n",
    "BATCH_SIZE = 512\n",
    "NUM_FEATURES = 64\n",
    "Z_DIM = 128\n",
    "LEARNING_RATE = 0.0002\n",
    "ADAM_BETA_1 = 0.5\n",
    "ADAM_BETA_2 = 0.999\n",
    "EPOCHS = 200\n",
    "CRITIC_STEPS = 3\n",
    "GP_WEIGHT = 10.0\n",
    "LOAD_MODEL = False\n",
    "ADAM_BETA_1 = 0.5\n",
    "ADAM_BETA_2 = 0.9"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7716fac-0010-49b0-b98e-53be2259edde",
   "metadata": {},
   "source": [
    "## 1. Prepare the data <a name=\"prepare\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "train_data = utils.image_dataset_from_directory(\n",
    "    \"/app/data/celeba-dataset/img_align_celeba/img_align_celeba\",\n",
    "    labels=None,\n",
    "    color_mode=\"rgb\",\n",
    "    image_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    seed=42,\n",
    "    interpolation=\"bilinear\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebae2f0d-59fd-4796-841f-7213eae638de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the data\n",
    "def preprocess(img):\n",
    "    \"\"\"\n",
    "    Normalize and reshape the images\n",
    "    \"\"\"\n",
    "    img = (tf.cast(img, \"float32\") - 127.5) / 127.5\n",
    "    return img\n",
    "\n",
    "\n",
    "train = train_data.map(lambda x: preprocess(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show some faces from the training set\n",
    "train_sample = sample_batch(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b86c15ef-82b2-4a75-99f7-2d8810440403",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(train_sample, cmap=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 2. Build the WGAN-GP <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "371eb69d-e534-4666-a412-b5b6fe24689a",
   "metadata": {},
   "outputs": [],
   "source": [
    "critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n",
    "x = layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\")(critic_input)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\")(x)\n",
    "x = layers.LeakyReLU()(x)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "x = layers.Conv2D(256, kernel_size=4, strides=2, padding=\"same\")(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "x = layers.Conv2D(512, kernel_size=4, strides=2, padding=\"same\")(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "x = layers.Conv2D(1, kernel_size=4, strides=1, padding=\"valid\")(x)\n",
    "critic_output = layers.Flatten()(x)\n",
    "\n",
    "critic = models.Model(critic_input, critic_output)\n",
    "critic.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "086e2584-c60d-4990-89f4-2092c44e023e",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_input = layers.Input(shape=(Z_DIM,))\n",
    "x = layers.Reshape((1, 1, Z_DIM))(generator_input)\n",
    "x = layers.Conv2DTranspose(\n",
    "    512, kernel_size=4, strides=1, padding=\"valid\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Conv2DTranspose(\n",
    "    256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Conv2DTranspose(\n",
    "    128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Conv2DTranspose(\n",
    "    64, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "generator_output = layers.Conv2DTranspose(\n",
    "    CHANNELS, kernel_size=4, strides=2, padding=\"same\", activation=\"tanh\"\n",
    ")(x)\n",
    "generator = models.Model(generator_input, generator_output)\n",
    "generator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88010f20-fb61-498c-b2b2-dac96f6c03b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class WGANGP(models.Model):\n",
    "    def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):\n",
    "        super(WGANGP, self).__init__()\n",
    "        self.critic = critic\n",
    "        self.generator = generator\n",
    "        self.latent_dim = latent_dim\n",
    "        self.critic_steps = critic_steps\n",
    "        self.gp_weight = gp_weight\n",
    "\n",
    "    def compile(self, c_optimizer, g_optimizer):\n",
    "        super(WGANGP, self).compile()\n",
    "        self.c_optimizer = c_optimizer\n",
    "        self.g_optimizer = g_optimizer\n",
    "        self.c_wass_loss_metric = metrics.Mean(name=\"c_wass_loss\")\n",
    "        self.c_gp_metric = metrics.Mean(name=\"c_gp\")\n",
    "        self.c_loss_metric = metrics.Mean(name=\"c_loss\")\n",
    "        self.g_loss_metric = metrics.Mean(name=\"g_loss\")\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        return [\n",
    "            self.c_loss_metric,\n",
    "            self.c_wass_loss_metric,\n",
    "            self.c_gp_metric,\n",
    "            self.g_loss_metric,\n",
    "        ]\n",
    "\n",
    "    def gradient_penalty(self, batch_size, real_images, fake_images):\n",
    "        alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)\n",
    "        diff = fake_images - real_images\n",
    "        interpolated = real_images + alpha * diff\n",
    "\n",
    "        with tf.GradientTape() as gp_tape:\n",
    "            gp_tape.watch(interpolated)\n",
    "            pred = self.critic(interpolated, training=True)\n",
    "\n",
    "        grads = gp_tape.gradient(pred, [interpolated])[0]\n",
    "        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))\n",
    "        gp = tf.reduce_mean((norm - 1.0) ** 2)\n",
    "        return gp\n",
    "\n",
    "    def train_step(self, real_images):\n",
    "        batch_size = tf.shape(real_images)[0]\n",
    "\n",
    "        for i in range(self.critic_steps):\n",
    "            random_latent_vectors = tf.random.normal(\n",
    "                shape=(batch_size, self.latent_dim)\n",
    "            )\n",
    "\n",
    "            with tf.GradientTape() as tape:\n",
    "                fake_images = self.generator(\n",
    "                    random_latent_vectors, training=True\n",
    "                )\n",
    "                fake_predictions = self.critic(fake_images, training=True)\n",
    "                real_predictions = self.critic(real_images, training=True)\n",
    "\n",
    "                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(\n",
    "                    real_predictions\n",
    "                )\n",
    "                c_gp = self.gradient_penalty(\n",
    "                    batch_size, real_images, fake_images\n",
    "                )\n",
    "                c_loss = c_wass_loss + c_gp * self.gp_weight\n",
    "\n",
    "            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)\n",
    "            self.c_optimizer.apply_gradients(\n",
    "                zip(c_gradient, self.critic.trainable_variables)\n",
    "            )\n",
    "\n",
    "        random_latent_vectors = tf.random.normal(\n",
    "            shape=(batch_size, self.latent_dim)\n",
    "        )\n",
    "        with tf.GradientTape() as tape:\n",
    "            fake_images = self.generator(random_latent_vectors, training=True)\n",
    "            fake_predictions = self.critic(fake_images, training=True)\n",
    "            g_loss = -tf.reduce_mean(fake_predictions)\n",
    "\n",
    "        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)\n",
    "        self.g_optimizer.apply_gradients(\n",
    "            zip(gen_gradient, self.generator.trainable_variables)\n",
    "        )\n",
    "\n",
    "        self.c_loss_metric.update_state(c_loss)\n",
    "        self.c_wass_loss_metric.update_state(c_wass_loss)\n",
    "        self.c_gp_metric.update_state(c_gp)\n",
    "        self.g_loss_metric.update_state(g_loss)\n",
    "\n",
    "        return {m.name: m.result() for m in self.metrics}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edf2f892-9209-42ee-b251-1e7604df5335",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a GAN\n",
    "wgangp = WGANGP(\n",
    "    critic=critic,\n",
    "    generator=generator,\n",
    "    latent_dim=Z_DIM,\n",
    "    critic_steps=CRITIC_STEPS,\n",
    "    gp_weight=GP_WEIGHT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2f48907-fa82-41b5-8caa-813b2f232c79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if LOAD_MODEL:\n",
    "    wgangp.load_weights(\"./checkpoint/checkpoint.ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b14665-4359-447b-be58-3fd58ba69084",
   "metadata": {},
   "source": [
    "## 3. Train the GAN <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b429fdad-ea9c-45a2-a556-eb950d793824",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the GAN\n",
    "wgangp.compile(\n",
    "    c_optimizer=optimizers.Adam(\n",
    "        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n",
    "    ),\n",
    "    g_optimizer=optimizers.Adam(\n",
    "        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a model save checkpoint\n",
    "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=\"./checkpoint/checkpoint.ckpt\",\n",
    "    save_weights_only=True,\n",
    "    save_freq=\"epoch\",\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class ImageGenerator(callbacks.Callback):\n",
    "    def __init__(self, num_img, latent_dim):\n",
    "        self.num_img = num_img\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        random_latent_vectors = tf.random.normal(\n",
    "            shape=(self.num_img, self.latent_dim)\n",
    "        )\n",
    "        generated_images = self.model.generator(random_latent_vectors)\n",
    "        generated_images = generated_images * 127.5 + 127.5\n",
    "        generated_images = generated_images.numpy()\n",
    "        display(\n",
    "            generated_images,\n",
    "            save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
    "            cmap=None,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c497b7-fa40-48df-b2bf-541239cc9400",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "wgangp.fit(\n",
    "    train,\n",
    "    epochs=EPOCHS,\n",
    "    steps_per_epoch=2,\n",
    "    callbacks=[\n",
    "        model_checkpoint_callback,\n",
    "        tensorboard_callback,\n",
    "        ImageGenerator(num_img=10, latent_dim=Z_DIM),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028138af-d3a5-4134-b980-d3a8a703e70f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the final models\n",
    "generator.save(\"./models/generator\")\n",
    "critic.save(\"./models/critic\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0765b66b-d12c-42c4-90fa-2ff851a9b3f5",
   "metadata": {},
   "source": [
    "## Generate images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86576e84-afc4-443a-b68d-9a5ee13ce730",
   "metadata": {},
   "outputs": [],
   "source": [
    "z_sample = np.random.normal(size=(10, Z_DIM))\n",
    "imgs = wgangp.generator.predict(z_sample)\n",
    "display(imgs, cmap=None)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
